# Task name; matches the dataset passed to run_qa.py via --dataset_name below.
name: squad_v2

# Hardware options for this task. Each candidate names one accelerator type;
# "T4:8" follows SkyPilot's "<accelerator>:<count>" form, i.e. eight GPUs.
resources:
  candidates:
    - accelerators: T4:8
    - accelerators: V100:8

# Local working directory for the task. The setup script applies
# "../callback.patch" from inside the cloned transformers repo, so the patch
# file is expected to live in this directory.
workdir: ./examples/benchmark/transformers_qa

# One-time environment preparation: creates the "hf" conda environment,
# installs SkyCallback and Hugging Face Transformers, and patches the
# question-answering example so that it reports progress through SkyCallback.
# The clone and patch steps are guarded so the script can be re-run on a
# machine where it has already completed.
setup: |
  conda create -n hf python=3.8 -y
  conda activate hf

  # Install SkyCallback
  pip install "git+https://github.com/skypilot-org/skypilot.git#egg=sky-callback&subdirectory=sky/callbacks/"

  # User setup
  # Pin the library to the same release as the example scripts checked out
  # below; an unpinned install can drift away from the v4.20.0 examples.
  pip install transformers==4.20.0

  # Only the v4.20.0 tag is needed, so fetch a shallow copy, and skip the
  # clone entirely when a previous setup run already created the directory.
  if [ ! -d transformers ]; then
    git clone --depth 1 --branch v4.20.0 https://github.com/huggingface/transformers.git
  fi
  cd transformers
  git checkout v4.20.0
  pip install -r examples/pytorch/question-answering/requirements.txt

  # Apply the patch to enable SkyCallback
  # If the patch reverse-applies cleanly it is already in place; applying it
  # a second time would fail, so only apply it when that check does not pass.
  if ! git apply --reverse --check ../callback.patch 2>/dev/null; then
    git apply ../callback.patch
  fi

# Fine-tunes and evaluates bert-base-uncased on SQuAD v2 using the
# question-answering example script from the transformers checkout.
# Flags are grouped as: model/data, preprocessing, training, output.
run: |
  conda activate hf
  cd transformers/examples/pytorch/question-answering/
  python run_qa.py \
    --model_name_or_path bert-base-uncased \
    --dataset_name squad_v2 \
    --version_2_with_negative \
    --max_seq_length 384 \
    --doc_stride 128 \
    --do_train \
    --do_eval \
    --num_train_epochs 2 \
    --learning_rate 3e-5 \
    --per_device_train_batch_size 12 \
    --output_dir outputs/
